package klog

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
	"strconv"
	"sync"
	"sync/atomic"
	"time"
	"unicode/utf8"

	"gitee.com/klogsdk/klog-go-sdk/internal/apierr"
	"github.com/bwmarrin/snowflake"
	"github.com/golang/protobuf/proto"
	"github.com/pierrec/lz4/v4"
)

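// worker buffers incoming logs, groups them by project/pool/source/filename and
// ships the resulting batches to the endpoint through a pool of sender goroutines.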
type worker struct {
	options *ClientOptions
	lock    *sync.Mutex

	rateLimiter     *RateLimit
	downSampler     *DownSampler
	buf             *LogGroupList
	bufSize         int
	bufLogCount     int
	curBatches      int32
	lastPushQueueAt time.Time
	stat            *Stat
	ch              chan *event
	stopped         bool
	wg              *sync.WaitGroup
	ctx             context.Context
	cancel          context.CancelFunc
	idGenerator     *snowflake.Node
}

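// event carries one batched LogGroupList from the buffering side to a sender goroutine.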
type event struct {
	logGroupList *LogGroupList
}

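// newWorker builds a worker from the given options, starts options.WorkerNum sender
// goroutines plus a flush ticker, and returns it ready to accept logs.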
func newWorker(options *ClientOptions, stat *Stat) *worker {
	ctx, cancel := context.WithCancel(context.Background())
	o := &worker{
		options: options,
		lock:    new(sync.Mutex),
		buf:     new(LogGroupList),
		stat:    stat,
		ch:      make(chan *event),
		wg:      new(sync.WaitGroup),
		ctx:     ctx,
		cancel:  cancel,
	}

	if options.RateLimit > 0 {
		o.rateLimiter = NewRateLimit(options.RateLimit, 10)
	}
	if options.DownSampleRate > 0 {
		o.downSampler = NewDownSampler(options.DownSampleRate)
	}
	if node, err := snowflake.NewNode(int64(options.MachineId)); err != nil {
		o.logger().Warnf("klog.worker: snowflake.NewNode error: %s", err.Error())
	} else {
		o.idGenerator = node
	}

	for i := 0; i < o.options.WorkerNum; i++ {
		o.wg.Add(1)
		ctx1, cancel1 := context.WithCancel(ctx)
		go o.worker(i, ctx1, cancel1)
	}
	go o.ticker()

	o.stat.setStartedAt(nowString())
	return o
}

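// stop flushes the remaining buffer, waits until all queued batches are acknowledged,
// then cancels the sender goroutines and waits for them to exit.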
func (o *worker) stop() {
	o.lock.Lock()
	defer o.lock.Unlock()
	o.stopped = true
	o.pushQueue(true)
	o.cancel()
	o.wg.Wait()
}

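// flush pushes the current buffer to the send queue; when wait is true it blocks
// until every queued batch has been processed.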
func (o *worker) flush(wait bool) {
	o.lock.Lock()
	defer o.lock.Unlock()
	o.pushQueue(wait)
}

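// push applies down-sampling and rate limiting, validates the log, and appends it
// to the buffer, flushing whenever the size or count limits are reached.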
func (o *worker) push(projectName, logPoolName, source, filename string, log *Log) error {
	o.lock.Lock()
	defer o.lock.Unlock()
	if o.stopped {
		return fmt.Errorf("klog.worker.push: cannot push because the worker has been stopped")
	}

	if o.downSampler != nil && !o.downSampler.Ok() {
		// Down-sampled: drop this log.
		o.stat.addDownSampledLogs(1)
		o.logger().Debugf("klog.worker: down sampled, drop 1 log, project=%s, pool=%s", projectName, logPoolName)
		return nil
	}

	// Apply rate limiting if configured.
	if o.rateLimiter != nil {
		o.rateLimiter.Wait()
	}

	// Validate this log before buffering it.
	if err := o.checkLog(log); err != nil {
		o.stat.addErroredLogs(1)
		return err
	}

	size := proto.Size(log)
	if size > MaxLogSize {
		return apierr.New(MaxLogSizeExceeded, fmt.Sprintf("the size[%d] of this log should not be greater than %d", size, MaxLogSize), nil)
	}

	if size+o.bufSize > MaxLogGroupSize {
		// Adding this log would push the buffered size over the limit; flush the buffer first.
		o.pushQueue(false)
	}

	// Buffer this log and flush once the size or count limit is reached.
	o.addBuf(projectName, logPoolName, source, filename, log, size)
	if o.bufSize >= MaxLogGroupSize || o.bufLogCount >= MaxBulkSize {
		o.pushQueue(false)
	}
	return nil
}

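// addBuf appends the log to the LogGroup matching project/pool/source/filename,
// creating a new group when none exists yet.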
func (o *worker) addBuf(projectName, logPoolName, source, filename string, log *Log, size int) {
	o.bufSize += size
	o.bufLogCount += 1
	for i := 0; i < len(o.buf.LogGroupList); i++ {
		if o.buf.LogGroupList[i].GetProject() == projectName &&
			o.buf.LogGroupList[i].GetPool() == logPoolName &&
			o.buf.LogGroupList[i].GetSource() == source &&
			o.buf.LogGroupList[i].GetFilename() == filename {
			o.buf.LogGroupList[i].Logs = append(o.buf.LogGroupList[i].Logs, log)
			return
		}
	}
	o.buf.LogGroupList = append(o.buf.LogGroupList,
		&LogGroup{
			Logs:     []*Log{log},
			Project:  &projectName,
			Pool:     &logPoolName,
			Source:   &source,
			Filename: &filename,
		})
}

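// pushQueue hands the current buffer over to the sender goroutines and resets it.
// With needAck it spins until every outstanding batch has been acknowledged.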
func (o *worker) pushQueue(needAck bool) {
	if o.bufSize > 0 {
		atomic.AddInt32(&o.curBatches, 1)

		o.lastPushQueueAt = time.Now()
		o.ch <- &event{
			logGroupList: o.buf,
		}
		o.buf = new(LogGroupList)
		o.bufSize = 0
		o.bufLogCount = 0
	}

	if needAck {
		for {
			if atomic.LoadInt32(&o.curBatches) == 0 {
				return
			}
			time.Sleep(time.Duration(1) * time.Millisecond)
		}
	}
}

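// ticker periodically flushes the buffer when it has stayed unchanged for a while,
// so that low-traffic periods do not leave logs sitting in memory indefinitely.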
func (o *worker) ticker() {
	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()
	preCount := 0
	for {
		select {
		case <-o.ctx.Done():
			return
		case <-ticker.C:
			o.lock.Lock()
			if time.Since(o.lastPushQueueAt) > 2*time.Second && o.bufLogCount > 0 && preCount == o.bufLogCount {
				o.pushQueue(false)
			}
			preCount = o.bufLogCount
			o.lock.Unlock()
		}
	}
}

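// responseMessage is the JSON body returned by the PutLogsM endpoint.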
type responseMessage struct {
	ErrorCode    string `json:"ErrorCode"`
	ErrorMessage string `json:"ErrorMessage"`
	DroppedLogs  int    `json:"DroppedLogs"`
}

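// worker is the sender loop run by each goroutine: it marshals each batch, POSTs it
// to the endpoint and retries with randomized backoff until it succeeds, exhausts
// MaxRetries, or is cancelled.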
func (o *worker) worker(number int, ctx context.Context, cancel context.CancelFunc) {
	defer cancel()
	defer o.wg.Done()
	var resMsg *responseMessage

	for {
		select {
		case <-ctx.Done():
			// Stop signal received.
			o.logger().Infof("klog.worker[%d]: cancel received", number)
			return
		case event := <-o.ch:
			count := 0
			for i := range event.logGroupList.LogGroupList {
				count += len(event.logGroupList.LogGroupList[i].Logs)
			}

			// Marshal the batch and, when configured, compress it with lz4.
			data, err := proto.Marshal(event.logGroupList)
			event = nil
			if err == nil && o.options.CompressMethod == CompressMethodLz4 {
				data, err = compressLz4(data)
			}
			if err == nil {

				// url
				url := fmt.Sprintf("%s/PutLogsM?DropIfPoolNotExists=%s",
					o.options.Endpoint, strconv.FormatBool(o.options.DropIfPoolNotExists))

				// loop for retry
				var retried int
				var code int
				for {
					o.stat.setCurRetried(retried)
					o.stat.setLastLogs(count)

					o.logger().Debugf("klog.worker[%d]: sending %d logs", number, count)

					if code, resMsg, err = o.send(url, data); err == nil {
						if code == 200 {
							// Sent successfully.
							o.stat.setLastSucceededAt(nowString())
							o.stat.addSentBatches(1)
							o.stat.addSentLogs(count - resMsg.DroppedLogs)
							o.stat.addDroppedLogs(resMsg.DroppedLogs)
							if resMsg.DroppedLogs > 0 {
								o.logger().Debugf("klog.worker[%d]: %d logs dropped because of log pool being not exists",
									number, resMsg.DroppedLogs)
							}
							break
						} else {
							// The server returned an error code.
							if resMsg != nil && resMsg.ErrorCode != "" {
								err = apierr.New(resMsg.ErrorCode, resMsg.ErrorMessage, nil)
							} else {
								err = apierr.New(HttpError, strconv.Itoa(code), nil)
							}
						}
					}

					o.logger().Warnf("klog.worker[%d]: send error, err=%s", number, err.Error())

					o.stat.setLastError(err.Error())
					if retried >= o.options.MaxRetries {
						o.stat.addFailedBatches(1)
						o.stat.addFailedLogs(count)
						o.logger().Warnf("klog.worker[%d]: max retries reached, stop retry, %d logs failed", number, count)
						break
					}

					o.logger().Infof("klog.worker[%d]: sleep then retry", number)

					// All other failures are retried after a randomized backoff.
					timer := makeRandomTimer(retried)
					select {
					case <-timer.C:
						retried++
						o.stat.addRetried(1)
						o.stat.setLastRetried(retried)
						o.stat.setLastRetriedAt(nowString())
						continue
					case <-ctx.Done():
						// Stop signal received; abort retrying.
						o.logger().Infof("klog.worker[%d]: cancel received, stop retry", number)
						atomic.AddInt32(&o.curBatches, -1)
						return
					}
				}
			} else {
				o.stat.addFailedBatches(1)
				o.stat.addFailedLogs(count)
				o.logger().Errorf("klog.worker[%d]: proto.Marshal error: %s", number, err.Error())
			}
			atomic.AddInt32(&o.curBatches, -1)
		}
	}
}

// send signs and POSTs one marshaled (and possibly compressed) batch to the given
// URL and decodes the JSON response body into a responseMessage.
func (o *worker) send(url string, data []byte) (code int, resMsg *responseMessage, err error) {
	var resp *http.Response
	var respData []byte
	resMsg = new(responseMessage)

	// request
	req, err := http.NewRequest("POST", url, bytes.NewReader(data))
	if err != nil {
		return 0, nil, apierr.New(HttpError, err.Error(), nil)
	}
	req.Header.Set("Content-Length", strconv.Itoa(len(data)))
	req.Header.Set("Content-Type", "application/x-protobuf")
	req.Header.Set("x-klog-api-version", o.options.AppName)
	req.Header.Set("x-klog-signature-method", "hmac-sha1")
	req.Header.Set("x-klog-compress-type", "lz4")
	req.Header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
	if err = signatureV2(req, o.options.Credentials); err != nil {
		return 0, nil, apierr.New(SignatureError, err.Error(), nil)
	}

	if resp, err = o.options.HTTPClient.Do(req); err != nil {
		return 0, nil, apierr.New(HttpError, err.Error(), nil)
	}
	// Always close the body so the underlying connection can be reused.
	defer resp.Body.Close()

	code = resp.StatusCode
	if resp.ContentLength != 0 {
		if respData, err = ioutil.ReadAll(resp.Body); err != nil {
			return code, resMsg, apierr.New(HttpError, fmt.Sprintf("reading response body error, code=%d, error=%s", code, err.Error()), nil)
		} else if len(respData) > 0 {
			if err = json.Unmarshal(respData, &resMsg); err != nil {
				return code, nil, apierr.New(HttpError, fmt.Sprintf("unmarshal response body error, code=%d, error=%s", code, err.Error()), nil)
			}
		}
	}
	return code, resMsg, nil
}

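// compressLz4 compresses the payload into an lz4 frame using a 1 MB block size.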
func compressLz4(data []byte) ([]byte, error) {
	buf := new(bytes.Buffer)
	z := lz4.NewWriter(buf)
	_ = z.Apply(lz4.BlockSizeOption(lz4.Block1Mb))
	_, err := z.Write(data)
	if err != nil {
		return nil, apierr.New(CompressLz4Error, fmt.Sprintf("failed to write, error=%s", err.Error()), nil)
	}
	err = z.Close()
	if err != nil {
		return nil, apierr.New(CompressLz4Error, fmt.Sprintf("failed to close, error=%s", err.Error()), nil)
	}
	return buf.Bytes(), nil
}

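// logger returns the Logger configured in the client options.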
func (o *worker) logger() Logger {
	return o.options.Logger
}

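// checkLog validates UTF-8 and size limits for every key/value pair and makes sure
// the log carries a non-empty __id__ field, generating a snowflake id if needed.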
func (o *worker) checkLog(log *Log) error {
	idIdx := -1
	contents := log.GetContents()
	for i := 0; i < len(contents); i++ {
		if !utf8.ValidString(contents[i].Key) {
			return apierr.New(InvalidUtf8InKey, "invalid UTF-8 in key", nil)
		}

		if !utf8.ValidString(contents[i].Value) {
			return apierr.New(InvalidUtf8InValue, "invalid UTF-8 in value", nil)
		}

		if len([]byte(contents[i].Key)) > MaxKeySize {
			return apierr.New(MaxKeySizeExceeded, fmt.Sprintf("the size[%d] of a key should not be greater than %d", len([]byte(contents[i].Key)), MaxKeySize), nil)
		}

		if len([]byte(contents[i].Value)) > MaxValueSize {
			return apierr.New(MaxValueSizeExceeded, fmt.Sprintf("the size[%d] of a value should not be greater than %d", len([]byte(contents[i].Value)), MaxValueSize), nil)
		}

		if contents[i].Key == "__id__" {
			idIdx = i
		}
	}

	if idIdx >= 0 {
		if contents[idIdx].Value == "" {
			contents[idIdx].Value = o.idGenerator.Generate().String()
		}
	} else {
		contents = append(contents, &Log_Content{
			Key:   "__id__",
			Value: o.idGenerator.Generate().String(),
		})
	}

	if len(contents) > MaxKeyCount {
		return apierr.New(MaxKeyCountExceeded, fmt.Sprintf("the number[%d] of keys in one log should not be greater than %d", len(contents), MaxKeyCount), nil)
	}
	log.Contents = contents
	return nil
}
